import os
import re
from typing import List
import tiktoken
import matplotlib.pyplot as plt


heading_pattern = lambda x : re.compile(f'{"=" * x} .* {"=" * x}\n')
sub_heading_pattern = lambda x : re.compile(f'\n{"#" * x}|\n' + '\*' * x)
SEP_LIST = [heading_pattern(x) for x in range(2, 6)] + [re.compile(r"\n\n")] + [sub_heading_pattern(x) for x in range(1, 6)]

def count_tokens(text, model="gpt-4"):
    """Return the number of tokens *text* encodes to for *model*.

    The encoding is looked up with ``tiktoken.encoding_for_model`` on every
    call, and the length of the encoded token list is returned.
    """
    return len(tiktoken.encoding_for_model(model).encode(text))

def filter_text(text, pattern=r'\{\|.*?\|\}'):
    """Separate the spans of *text* that match *pattern* from the rest.

    Matching uses ``re.DOTALL``, so a match may extend across line breaks.
    The default pattern lazily matches spans that open with ``{|`` and close
    with ``|}``.

    Returns:
        A ``(matches, remainder)`` tuple: the list of all non-overlapping
        matches, and the list of text pieces left after splitting *text* on
        those matches.
    """
    matcher = re.compile(pattern, re.DOTALL)
    return matcher.findall(text), matcher.split(text)

def merge_chunks(chunks, max_tokens=128, token_counter=None):
    """Concatenate consecutive chunks until each piece exceeds *max_tokens*.

    Chunks are appended (with no separator) to a running piece; as soon as
    the running piece counts more than *max_tokens* tokens it is emitted and
    a new piece is started. Emitted pieces are therefore longer than
    *max_tokens*, except possibly the last one.

    Args:
        chunks: Iterable of strings to merge, in order.
        max_tokens: Token count a running piece must exceed before it is
            emitted.
        token_counter: Optional callable mapping a string to its token
            count. Defaults to ``count_tokens``.

    Returns:
        The list of merged strings. Any non-empty text still accumulated
        after the last chunk is emitted as a final, shorter piece so that no
        input text is lost.
    """
    if token_counter is None:
        token_counter = count_tokens

    merged_chunks = []
    current_chunk = ''
    for chunk in chunks:
        current_chunk += chunk
        if token_counter(current_chunk) > max_tokens:
            merged_chunks.append(current_chunk)
            current_chunk = ''

    # Flush the remainder: without this the tail of the input (or the whole
    # input, when it never exceeds the threshold) would be dropped.
    if current_chunk:
        merged_chunks.append(current_chunk)
    return merged_chunks

def split_into_chunks(text, level=0, max_chunk_tokens=512):
    """Recursively split *text* on progressively finer separators.

    The text is split on ``SEP_LIST[level]``. Any resulting piece that counts
    more than *max_chunk_tokens* tokens is split again with the next
    separator in ``SEP_LIST``; once the separators are exhausted the text is
    split on single newlines instead. At every level the collected pieces are
    passed through ``merge_chunks`` before being returned.

    The text matched by a separator is not kept in the output, because the
    separator patterns contain no capture groups.

    Args:
        text: The text to split.
        level: Index into ``SEP_LIST`` of the separator to use first.
        max_chunk_tokens: Pieces with more tokens than this are split further.

    Returns:
        The list of strings produced by ``merge_chunks``.
    """
    if level >= len(SEP_LIST):
        # No structural separators left: fall back to individual lines.
        return merge_chunks(re.split('\n', text))

    final_chunks = []
    for chunk in re.split(SEP_LIST[level], text):
        if count_tokens(chunk) > max_chunk_tokens:
            final_chunks.extend(
                split_into_chunks(chunk, level + 1, max_chunk_tokens)
            )
        else:
            final_chunks.append(chunk)
    return merge_chunks(final_chunks)

def split_into_chunks_by_tokens(text, max_tokens=4096):
    """Cut *text* into pieces of at most *max_tokens* tokens each.

    Tokenization uses the tiktoken encoding for ``text-embedding-3-large``.
    When the text does not encode to more than *max_tokens* tokens it is
    returned unchanged as a one-element list; otherwise the token sequence is
    sliced into consecutive windows of *max_tokens* tokens and each window is
    decoded back to a string.
    """
    encoding = tiktoken.encoding_for_model("text-embedding-3-large")
    token_ids = encoding.encode(text)
    total = len(token_ids)

    if total <= max_tokens:
        return [text]

    pieces = []
    for start in range(0, total, max_tokens):
        window = token_ids[start:start + max_tokens]
        pieces.append(encoding.decode(window))
    return pieces

def split_all(folder_path: str, show_stats=False, debug=False) -> List[str]:
    """Chunk every ``.txt`` file in *folder_path* and return all chunks.

    Each file is read as UTF-8 and passed to ``filter_text``. The spans it
    extracts are kept whole as chunks; each remaining text piece is broken up
    with ``split_into_chunks``. Per file, the extracted spans come first,
    followed by the chunks of the remaining text.

    Args:
        folder_path: Directory to scan (not recursive). Files are processed
            in sorted name order.
        show_stats: When true, print the largest chunk token count and show a
            histogram of token counts (skipped when there are no chunks).
        debug: When true, print the first chunk found with more than 8000
            tokens and terminate the process with exit status 0.

    Returns:
        The chunks of all files, concatenated in processing order.
    """
    token_counts = []
    all_chunks = []

    # os.listdir returns names in arbitrary order; sorting makes the order of
    # the returned chunks reproducible between runs and machines.
    for filename in sorted(os.listdir(folder_path)):
        if not filename.endswith(".txt"):
            continue
        file_path = os.path.join(folder_path, filename)

        with open(file_path, 'r', encoding='utf-8') as file:
            content = file.read()

        extracted_content_chunks, split_text_chunks = filter_text(content)

        # Copy so the list returned by filter_text is not mutated in place.
        chunks = list(extracted_content_chunks)
        for piece in split_text_chunks:
            chunks.extend(split_into_chunks(piece))
        all_chunks.extend(chunks)

        if not (show_stats or debug):
            continue

        for i, chunk in enumerate(chunks):
            token_count = count_tokens(chunk)
            token_counts.append(token_count)
            if debug and token_count > 8000:
                print(f"Chunk {i}: {chunk}")
                # Same effect as exit(0), without depending on the helper
                # that the site module injects into builtins.
                raise SystemExit(0)

    if show_stats:
        max_tokens = max(token_counts, default=0)
        print(f"Max Tokens: {max_tokens}")
        # With no chunks there is nothing to plot, and max() of an empty
        # list would raise ValueError.
        if token_counts:
            step = 50
            plt.hist(
                token_counts,
                bins=range(0, max_tokens + step, step),
                edgecolor='black',
            )
            plt.title("Token Distribution in Chunks")
            plt.xlabel("Token Count")
            plt.ylabel("Number of Chunks")
            plt.show()

    return all_chunks


if __name__ == "__main__":
    # Folder holding the text files to chunk.
    wiki_dir = "./wiki_processed"

    # Chunk every file and display the token statistics.
    chunks = split_all(wiki_dir, show_stats=True, debug=False)

    from tqdm import tqdm

    # Walk the chunks once under a progress bar; nothing is done per chunk.
    for _ in tqdm(chunks):
        pass

    print(len(chunks))

